BONN: Bayesian Optimized Binary Neural Network
Algorithm 6 Pruning 1-bit CNNs with Bayesian learning
Input:
The pre-trained 1-bit CNN model with parameters K, the reconstruction vector w, the learning rate η, regularization parameters λ and θ, variance ν, convergence rate γ, and the training dataset.
Output:
The pruned BONN with updated K, w, μ, σ, c_m, σ_m.
1: repeat
2:   // Forward propagation
3:   for l = 1 to L do
4:     K^l_{i,j} = (1 − γ)K^l_{i,j} + γK̄^l_j;
5:     k̂^l_i = w^l ∘ sign(k^l_i), ∀i; // Each element of w^l is replaced by the average of all elements of w^l.
6:     Perform activation binarization; // using the sign function
7:     Perform 2D convolution with k̂^l_i, ∀i;
8:   end for
9:   // Backward propagation
10:  Compute δ_{k̂^l_i} = ∂L_S/∂k̂^l_i, ∀l, i;
11:  for l = L to 1 do
12:    Calculate δ_{k^l_i}, δ_{w^l}, δ_{μ^l_i}, δ_{σ^l_i}; // using Eqs. 3.115–3.120
13:    Update parameters k^l_i, w^l, μ^l_i, σ^l_i using SGD;
14:  end for
15:  Update c_m, σ_m;
16: until filters in the same group are similar enough
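As a sketch, the forward-propagation steps of Algorithm 6 (lines 4–5) might look as follows in NumPy. The shapes, the group-index array, and the function name are illustrative assumptions, not the authors' implementation:

```python
import numpy as np

def forward_step(K, group_idx, w, gamma):
    """Sketch of Algorithm 6, lines 4-5, for one layer (shapes assumed).

    K:         (I, ...) real-valued filters k^l_i of the layer
    group_idx: (I,) group index j assigned to each filter i (assumed grouping)
    w:         (m,) reconstruction vector w^l
    gamma:     convergence rate pulling each filter toward its group mean
    """
    K = K.copy()
    for j in np.unique(group_idx):
        members = group_idx == j
        K_bar = K[members].mean(axis=0)                        # group mean K̄^l_j
        K[members] = (1 - gamma) * K[members] + gamma * K_bar  # line 4
    # Each element of w^l is replaced by the average of all its elements,
    # so the Hadamard product reduces to a scalar scaling of the signs:
    K_hat = w.mean() * np.sign(K)                              # line 5
    return K, K_hat
```

With gamma = 1 every filter in a group collapses to the group mean in one step; smaller gamma interpolates toward it gradually, matching the convergence criterion in line 16.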
Updating K^l_{i,j}: In pruning, we aim to make the filters converge gradually to their mean, so we replace each filter K^l_{i,j} with its corresponding group mean K̄^l_j. The gradient with respect to K^l_{i,j} is

\[
\frac{\partial L}{\partial K^l_{i,j}}
= \frac{\partial L_S}{\partial K^l_{i,j}} + \frac{\partial L_B}{\partial K^l_{i,j}} + \frac{\partial L_P}{\partial K^l_{i,j}}
= \frac{\partial L_S}{\partial \bar{K}^l_j}\frac{\partial \bar{K}^l_j}{\partial K^l_{i,j}}
+ \frac{\partial L_B}{\partial \bar{K}^l_j}\frac{\partial \bar{K}^l_j}{\partial K^l_{i,j}}
+ \frac{\partial L_P}{\partial K^l_{i,j}}
= \frac{1}{I_j}\left(\frac{\partial L_S}{\partial \bar{K}^l_j} + \frac{\partial L_B}{\partial \bar{K}^l_j}\right)
+ 2\,(K^l_{i,j} - \bar{K}^l_j)
+ 2\nu\,(\Psi^l_j)^{-1}(K^l_{i,j} - \bar{K}^l_j),
\tag{3.120}
\]

where \(\bar{K}^l_j = \frac{1}{I_j}\sum_{i=1}^{I_j} K^l_{i,j}\) is the mean used to update the filters in group j. To prune redundant filters, we keep only the first filter in each group and remove the others. However, removing filters changes the distribution of the input channels of the batch norm layer, causing a dimension mismatch with the next convolutional layer. To solve this problem, we keep the batch norm layer at its original size and set the entries corresponding to the removed filters to zero. In this way, the removed information is retained to the greatest extent. In summary, the proposed method is trainable end to end. The learning procedure is detailed in Algorithms 5 and 6.
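A minimal sketch of this pruning step, assuming a simple per-channel representation of the batch norm parameters (the grouping and argument names are illustrative, not the authors' code):

```python
import numpy as np

def prune_groups(filters, group_idx, bn_gamma, bn_beta):
    """Keep the first filter of each group and mark the rest as removed.

    Instead of shrinking the batch norm layer, its scale/shift entries for
    the removed channels are set to zero, so those channels output zero and
    the input dimension of the next convolutional layer is unchanged.

    filters:   (I, ...) filters of the layer (assumed layout)
    group_idx: (I,) group id of each filter
    bn_gamma, bn_beta: (I,) batch-norm scale/shift per output channel
    """
    keep = np.zeros(len(filters), dtype=bool)
    for j in np.unique(group_idx):
        first = np.flatnonzero(group_idx == j)[0]  # first filter in group j
        keep[first] = True
    bn_gamma = np.where(keep, bn_gamma, 0.0)       # removed channels output zero
    bn_beta = np.where(keep, bn_beta, 0.0)
    return keep, bn_gamma, bn_beta
```

Zeroing the batch norm entries rather than deleting them keeps every downstream tensor shape intact, which is why no layer after the pruned one needs to be modified.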